import pandas as pd
import numpy as np
import pickle
from pre_handle import PreHandle
from fileRW import ReadFile
from strategies import Strategy
from log_recording import LogRecording
from train_model import TrainModel


class BackTest:
    file_path = "./data.pkl"
    data_in = dict()
    data_avg = dict()

    def __init__(self, file_path="./data.pkl", start_date='2010-01-04', end_date='2019-12-31'):
        '''
        构造函数，主要完成回测时间段的处理和数据缺失值的预处理
        :param file_path: 原始数据路径
        :param start_date: 回测开始时间
        :param end_date: 回测结束时间
        '''
        self.file_path = file_path
        rf = ReadFile(file_path=file_path)
        data1 = rf.read_pickle()

        # 处理回测时间
        days = sorted(data1.keys())  # 日期排序后的列表，存储所有日期

        # 指定日期的处理代码，对字典取其子字典
        start_index = 0
        end_index = len(days) - 1
        for i in range(len(days)):
            if days[i] >= start_date:
                start_index = i
                break
        for i in range(start_index, len(days)):
            if days[i] <= end_date:
                end_index = i
        key_list = days[start_index: end_index + 1]  # 更新为指定时间段
        data1 = {key: value for key, value in data1.items() if key in key_list}

        data_copy = data1.copy()
        PreHandle().prehandle(data_copy)
        self.data_in = data_copy

        data_copy_avg = data1.copy()
        PreHandle().prehandle_db_avg_stgy(data_copy_avg)
        self.data_avg = data_copy_avg

    def get_income(self, start, end, dict1, share_series, list_len):
        '''
        获得一个持仓周期结束的时候涨幅涨了多少
        '''
        sub = (dict1[end]['close'] - dict1[start]['open']) / dict1[start]['open']
        sub_series = pd.Series(sub, index=sub.index)
        sum = 0

        for i in range(list_len):
            sum += sub_series[share_series[i]]

        return sum / list_len

    def back_trader(self, strategy_name='涨幅策略'):
        '''
        回测函数
        策略的持仓与调仓方式不同，计算收益率的方法也不同
        :param strategy_name: 策略名称，目前支持“涨幅策略”和“双均线策略”
        :return: 每日相较于回测开始日期的收益率
        '''

        key_list = list(self.data_in.keys())
        res_dict1 = {}  # 记录每日持仓
        listt = len(key_list) * [0.]
        res_series = pd.Series(listt, index=key_list)   # 记录每日相对于回测第一天的总收益率

        # 两种策略的调仓标准不同、收益计算方式也不同，所以需要用if判断。
        if strategy_name == '涨幅策略':

            dict1 = self.data_in

            # 涨幅策略所需参数的初始化
            pre_day_idx = 3
            open_cycle = 22     # 持仓周期
            position_num = 10   # 持有股票数
            hier_pct = 0.       # 开始的时候收益率为0
            shares_list = []    # 当前持有的股票名称的列表

            for _ in range(position_num):
                shares_list.append("")
            share_series = pd.Series(shares_list)  # 维护当前持仓，初始值为空的持有的10只股票的名称的一个series

            i = 0   # 回测日期索引
            while i < len(key_list):    # 开始回测
                pre_day_idx, hier_pct, share_series = Strategy().increase_strategy(i, pre_day_idx, key_list, dict1,
                                                                                   share_series, position_num,
                                                                                   open_cycle, hier_pct)

                if i <= 2:
                    i += 1
                    continue

                new_income = self.get_income(key_list[pre_day_idx], key_list[i], dict1, share_series,
                                             position_num)  # 获得到本周期到现在为止的涨幅
                sub = (dict1[key_list[i]]['close'] - dict1[key_list[pre_day_idx]]['open']) / \
                      dict1[key_list[pre_day_idx]]['open']
                sub_series = pd.Series(sub, index=sub.index)
                res_series[key_list[i]] = (1 + hier_pct) * (new_income + 1) - 1

                # 需要新建一个series，键为shares_series的值也就是股票名称，值为sub_series中的股票的涨幅
                new_series = pd.Series(np.zeros(position_num), index=share_series.values)
                for name in new_series.index:
                    new_series[name] = sub_series[name]

                res_dict1[key_list[i]] = new_series  # 用于将每天的持仓存入res_dict1中

                i += 1

        elif strategy_name == '双均线策略':

            dict1 = self.data_avg

            # 双均线策略所需参数的初始化
            longer_timelen = 10     # 长期移动平均周期
            shorter_timelen = 5     # 短期移动平均周期
            position_state = pd.Series([False] * 21, index=dict1[key_list[0]].index)  # 表示当前个股票的持仓状态，初始时均为0表示未持仓
            income_now = pd.Series([0.] * 21, index=dict1[key_list[0]].index)  # 表示到昨天为止各只股票的收益率
            i = longer_timelen  # 日期索引

            while i < len(key_list):    # 开始回测
                short_series, longer_series = Strategy().double_average_strategy(i, dict1, key_list, shorter_timelen, longer_timelen, position_state)

                """
                收益的话，如果今天没有持有,那么就是和昨天一样，如果今天持有，那么就是要和昨天收盘价进行计算，然后和之前的进行累积
                （先计算第一个数组这里状态数组对应为True取第二个，否则取第一个）
                然后还要对每只股票的持仓状态进行更新：
                    首先使用两个数组进行异或，一个是5日均值大于10日均值，另外一个是本身的持仓状态
                    然后异或的结果如果是0说明不需要更新，1需要更新
                    所以只需要将当前持仓状态和这个异或结果进行异或返回给当前持仓状态即可
                """
                left_st = pd.Series([True] * 21, position_state.index) ^ position_state  # 第一个的状态数组
                today_income_rate = (dict1[key_list[i]]['close'] - dict1[key_list[i - 1]]['close']) / \
                                    dict1[key_list[i - 1]][
                                        'close']
                if_income_now = (income_now + 1) * (1 + today_income_rate) - 1  # 如果今天持有，总的收益数组
                income_now = left_st * income_now + position_state * if_income_now

                state1 = (short_series > longer_series) ^ position_state
                position_state ^= state1

                res_dict1[key_list[i]] = income_now
                res_series[key_list[i]] = income_now.mean()

                i += 1

        LogRecording.log_to_json(res_dict1, write_path="log.json")  # 记录日志

        return res_series


if __name__ == '__main__':

    bt = BackTest(file_path="./data.pkl", start_date='2013-04-12', end_date='2018-05-20')
    pd.set_option('display.max_columns', None)

    print(bt.back_trader(strategy_name='涨幅策略'))

